{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "# Tabular Data Explanation Benchmarking: XGBoost Regression" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "This notebook demonstrates how to use the benchmark utility to evaluate the performance of an explainer on tabular data. In this demo, we evaluate the explanations that [TreeExplainer][treeexplainer_doclink] produces for an XGBoost regression model. The metrics used for the evaluation are \"keep positive\" and \"keep negative\". The masker used here is [IndependentMasker][indmasker_doclink], but the same workflow generalizes to other tabular maskers.\n", "\n", "The new `benchmark` utility uses the new API, in which `MaskedModel` wraps the user-supplied model and evaluates it on masked versions of the inputs.\n", "\n", "[treeexplainer_doclink]: ../../../generated/shap.TreeExplainer.rst#shap.TreeExplainer\n", "[indmasker_doclink]: ../../../generated/shap.maskers.Independent.rst#shap.maskers.Independent" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:27:13.000431453Z", "start_time": "2023-10-19T18:27:10.273671927Z" }, "collapsed": false }, "outputs": [], "source": [ "import xgboost\n", "from sklearn.model_selection import train_test_split\n", "\n", "import shap\n", "import shap.benchmark as benchmark" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Load Data and Model" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:27:13.632300747Z", "start_time": "2023-10-19T18:27:13.041684876Z" }, "collapsed": false }, "outputs": [], "source": [ "# train an XGBoost regression model whose predictions will be explained\n", "untrained_model = xgboost.XGBRegressor(n_estimators=100, subsample=0.3)\n", "X, y = shap.datasets.california()\n", "X = X.values\n", "\n", "test_size = 0.3\n", "random_state = 0\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)\n", "\n", "model = untrained_model.fit(X_train, y_train)" ] },
{ "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Define Explainer Masker" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:27:13.646468289Z", "start_time": "2023-10-19T18:27:13.631831092Z" }, "collapsed": false }, "outputs": [], "source": [ "# Independent masker: masked features are filled in from the background samples in X\n", "masker = shap.maskers.Independent(X)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Create Explainer Object" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:27:13.807047523Z", "start_time": "2023-10-19T18:27:13.639331017Z" }, "collapsed": false }, "outputs": [], "source": [ "# shap.Explainer selects a TreeExplainer automatically for XGBoost models\n", "explainer = shap.Explainer(model, masker)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Run SHAP Explanation" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:27:51.633979886Z", "start_time": "2023-10-19T18:27:13.808642840Z" }, "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 98%|===================| 20313/20640 [00:38<00:00] " ] } ], "source": [ "shap_values = explainer(X)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Define Metrics (Sort Order & Perturbation Method)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:27:51.635039203Z", "start_time": "2023-10-19T18:27:51.633811769Z" }, "collapsed": false }, "outputs": [], "source": [ "sort_order = \"positive\"\n", "perturbation = \"keep\"" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## Benchmark Explainer" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { 
"end_time": "2023-10-19T18:28:20.607251305Z", "start_time": "2023-10-19T18:27:51.636823115Z" }, "collapsed": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "92b8af2033914663b4129e9938c95760", "version_major": 2, "version_minor": 0 }, "text/plain": "SequentialMasker: 0%| | 0/20640 [00:00" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sp = benchmark._sequential.SequentialPerturbation(explainer.model, explainer.masker, sort_order, perturbation)\n", "sp_result = sp(\"SequentialPerturbation\", shap_values.values, X)\n", "sp.plot(sp_result.curve_x, sp_result.curve_y, sp_result.value)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:28:20.651729706Z", "start_time": "2023-10-19T18:28:20.609001824Z" }, "collapsed": false }, "outputs": [], "source": [ "sort_order = \"negative\"\n", "perturbation = \"keep\"" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-10-19T18:28:50.123076599Z", "start_time": "2023-10-19T18:28:20.649765735Z" }, "collapsed": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6b727d22b664463fb59e747dafd1f37b", "version_major": 2, "version_minor": 0 }, "text/plain": "SequentialMasker: 0%| | 0/20640 [00:00" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sp = benchmark._sequential.SequentialPerturbation(explainer.model, explainer.masker, sort_order, perturbation)\n", "sp_result = sp(\"SequentialPerturbation\", shap_values.values, X)\n", "sp.plot(sp_result.curve_x, sp_result.curve_y, sp_result.value)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 
"3.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }